# encoding:utf-8
"""
@author wangyizhong, Zhilin Wang
@date 2021/09/03
@desc 自定义算子中的通用方法包
"""

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, SparkSession
import psycopg2
import json
import requests
import urllib3
import os
import pymysql
import pandas as pd
import io
from common.config.config import *

# Executors may lack a writable home dir; point egg extraction at /tmp.
os.environ['PYTHON_EGG_CACHE'] = '/tmp'
# Silence the InsecureRequestWarning triggered by verify=False requests below.
urllib3.disable_warnings()

# Bind a DB-API driver under the alias ``mysql``: MySQLdb is tried first,
# then pymysql.  Because the second try rebinds the name, pymysql wins
# whenever it is installed.  The ``print("")`` keeps the original (near
# silent) behavior when an import fails.
try:
	import MySQLdb as mysql
except ImportError:
	print("")

try:
	import pymysql as mysql
except ImportError:
	print("")

# Archive shipped to Spark executors so psycopg2 is importable on the cluster
# (see the commented addPyFile calls in the session helpers below).
pyFiles = ["hdfs:///datascience/python/psycopg2-2_8_1.zip"]

def get_feature_cols(dataframe):
	"""Return all column names of *dataframe* joined by commas.

	The result keeps a trailing comma (e.g. ``"a,b,_record_id_,"``) to stay
	backward compatible with existing callers.  If the mandatory
	``_record_id_`` column is missing, a warning is printed (historically
	this check never raised, so that contract is preserved).

	Args:
		dataframe: a Spark DataFrame, or anything exposing ``.dtypes`` as a
			list of ``(name, type)`` tuples.

	Returns:
		str: comma-joined column names with a trailing comma; ``""`` when
		there are no columns.
	"""
	cols = [name for name, _ in dataframe.dtypes]
	if "_record_id_" not in cols:
		print("请检查并确保数据集 df_ult 中包含 _record_id_ 列")
	# join instead of repeated += (the original loop was quadratic);
	# the trailing comma is intentional -- see docstring.
	return "".join(col + "," for col in cols)


def load_df_spark(tableName):
	"""Read a GreenPlum table into a Spark DataFrame via JDBC.

	The original guarded session creation with ``"s2P3A1r9K4" not in
	locals()``, which is always true at entry to a fresh function scope --
	the guard was dead code and has been removed.  ``getOrCreate`` already
	reuses an existing session, so behavior is unchanged.

	Args:
		tableName: fully qualified GreenPlum table name.

	Returns:
		pyspark.sql.DataFrame with the table contents.
	"""
	session = SparkSession.Builder().appName("demo-test2").getOrCreate()
	return session.read.jdbc(url=GP_URL, table=tableName, properties=GP_PROP)


def pidTableCheck(pipelineId, tableName):
	"""Check table-access permission via the nebula API, then load the table.

	Posts ``pipelineId``/``tableName`` to the userTableCheck endpoint.  When
	the service answers ``1`` the table is read from GreenPlum via JDBC;
	otherwise a sentinel error-code string is printed and returned so the
	caller can detect the denied check.

	Args:
		pipelineId: pipeline identifier the table is checked against.
		tableName: fully qualified GreenPlum table name.

	Returns:
		A pyspark DataFrame on success, or the sentinel string
		``"UNIQUES33E45C654R765E87T24C75O856D976E26"`` when access is denied.
	"""
	url = "https://nebula-dev.zjvis.net/api/model/userTableCheck"

	params = {
		"pipelineId": pipelineId,
		"tableName": tableName
	}

	# SECURITY NOTE(review): the rememberMe cookie is hardcoded -- it leaks a
	# credential into source control and will silently expire.  It should be
	# injected from configuration instead.
	headers = {
		"Origin": "https://nebula-dev.zjvis.net",
		"Connection": "close",
		"Content-Type": "application/json;charset=UTF-8",
		"Referer": "https://nebula-dev.zjvis.net/project/1292/data",
		"Cookie": "rememberMe=KCygMFvZDrdzG/HZ+skCKMsL9EiHE5QPkuS1aeBjBV4aRyVPrn2ENtc2aiiUup8jxfYVYmAdGyd33WDd2Fu76A4WR7gU0bPKTJ1hw64GGK+Q4rWucLai89kBgYdlK2fPo9ekiGvZk2gss8zQDeIiaIeRcWbCUelMo6xcSVRdu9tFCUHodQ6pMD8mBC5+TIlQ4LDyi5xOZtMjCfKZDTtKhi9sDZn5duoB+mW8AoRdJ5evqwU0shq9IrkifNmrW7FspfrPCInxYuAlOmVd90+q1xpVdGTDstlF1k1x7bOcqKpg/TsC4P62yPG0hsmfKH4cXN7mtN+Qey0YQff4VblcY+aoV4GGqVvLsfFoaIsAKiOAQVVmgGfEYMcvf6exfxwILwy9rnRj+U+w5EhWryu2nw4eLVQk8ku7n6elMX124X5UgpZ45CLPhncER4hNuc5bFozkCGbcQixK0cIHbWUaR7dArS0Mjoc9mBuG8RDrLovefspQz9VKqsYd+b3jZCtfyt087tT1qdUj/RooZiph8N5qCQsTAh2r7W+gcqiNgn9EUWKk/m6Yb3wMzSSt+ElR5fcP5tOnrqNlKxglWHcNFEhmQb164CxymNB+Z5Y9kDbe/VkvIANgodILdc1zgclSH52zBxtEdZ/gFKnxUm5Aho6zhWBW1/ENslhYthQwXLb2jP8sdCY4cN/2wh1c/4k7H7Mtf3z4X20qJBsfClmhNEcHV8HcdrynaBGD/rauogcRqAQz4tPEROvhASVmLDake28jUbUKVklP4kj5NvDqwPWf59zP4XF7mTi2plR2HEqnXE810EJH9FRF/ZRMT03UGmAyZRbiV8ct5wjzE6lXTJzasOqCVOKBAf72AeKSWfsKRbv/6YUrHwVJ7ofYk8BTFAF7ucHwMnw3TgSpbtpAsH9cq55LkVBcDQALjpL1rHCeqTyyVhllE1NfJKu5oMaQpY8w+PPBoSdYptYNEJEhnIB6a9U9OUbsp1nvaohNtQcuPfLxugnTXZN3OuKEc0kAuJKNzPgZvcqfaxiYTjkQ4gkok/2PiheOHhntPhAn+QLo7nJlhG5EVWSjXc8ugLaJKakRTKtH0tNWY7OrMCob4WifYZbO1GXrdlJ5qt0ujA2N5d3JvhJN2XeWGhQvC0Rm4dSv8z8SL7LPxr6LU/fYEmX9hT+SdtVcLbgDT96RwXCGd445pIOAvYB2/iT8f1j+cbMns444lKCekI2QOy6VjGLU9FdmqXJ4CVWPfNPsSVJcp5Tz5NmjQSMOt1Ztcq1y54uBIYGLEe3sh6faccbl/0xfYEv0NOd/DxFrtfmfAnHitwRzRk/gQlRqTaMbc+qeVK1iwgPpfZJn6yscUjSEm6DmZ5SaZUCo+PZHCRDVrBfGDM47m0tZFh3bxyb4WqiNLpK3QJZjzl1bykP4qyhpPCHf8EX9vpLyWSBZf/8VulU="
	}
	requests.adapters.DEFAULT_RETRIES = 5
	session = requests.Session()
	# verify=False: dev endpoint presumably uses a self-signed certificate
	# (urllib3 warnings are disabled at module import).
	res = session.post(url, data=json.dumps(params), headers=headers, verify=False)
	if res.json() != 1:
		print("UNIQUES33E45C654R765E87T24C75O856D976E26")
		return "UNIQUES33E45C654R765E87T24C75O856D976E26"
	# The original ``"s2P3A1r9K4" not in locals()`` guard is always true in a
	# fresh function scope -- dead code, removed; getOrCreate reuses sessions.
	spark_session = SparkSession.Builder().appName("demo-test2").getOrCreate()
	return spark_session.read.jdbc(url=GP_URL, table=tableName, properties=GP_PROP)


def save_df_to_gp(dataframe, table_name):
	"""Persist a Spark DataFrame into GreenPlum, overwriting any existing table."""
	dataframe.write.jdbc(url=GP_URL, table=table_name, mode="overwrite", properties=GP_PROP)


def get_saprk(appName=""):
	"""Return a SparkSession named *appName* ("demo-test2" when empty).

	The function name typo ("saprk") is kept for backward compatibility with
	existing callers.  The original cached-session branch tested
	``"s2P3A1r9K4" in locals()``, which is always false at function entry,
	so it was unreachable dead code; ``getOrCreate`` already reuses any
	existing session, so behavior is unchanged.

	Args:
		appName: Spark application name; defaults to "demo-test2".

	Returns:
		pyspark.sql.SparkSession.
	"""
	if appName == "":
		appName = "demo-test2"
	session = SparkSession.Builder().appName(appName).getOrCreate()
	# session.sparkContext.addPyFile(",".join(pyFiles))  # ship psycopg2 to executors
	return session


def getOrCreateSparkSession(appName=""):
	"""Return a SparkSession named *appName* ("demo-test2" when empty).

	The original branch testing ``'spark' in locals()`` is always false at
	function entry (a fresh local scope has no such name), so it was dead
	code and has been removed; ``getOrCreate`` already returns the active
	session when one exists, so behavior is unchanged.

	Args:
		appName: Spark application name; defaults to "demo-test2".

	Returns:
		pyspark.sql.SparkSession.
	"""
	if appName == "":
		appName = "demo-test2"
	spark = SparkSession.Builder().appName(appName).getOrCreate()
	# spark.sparkContext.addPyFile(",".join(pyFiles))  # ship psycopg2 to executors
	return spark


def readFromGreenPlum(spark, tableName):
	"""Load a GreenPlum table through the supplied SparkSession.

	Args:
		spark: an active SparkSession.
		tableName: fully qualified GreenPlum table name.

	Returns:
		pyspark.sql.DataFrame with the table contents.
	"""
	return spark.read.jdbc(url=GP_URL, table=tableName, properties=GP_PROP)


def readMysql(spark, tableName):
	"""Load a MySQL table through the supplied SparkSession.

	Args:
		spark: an active SparkSession.
		tableName: MySQL table name.

	Returns:
		pyspark.sql.DataFrame with the table contents.
	"""
	return spark.read.jdbc(url=MYSQL_URL, table=tableName, properties=MYSQL_PROP)


def saveTableForGreenPlum(dataset, tableName):
	"""Save a Spark dataset into GreenPlum, replacing any same-named table.

	Fixes over the original:
	- the DROP TABLE was never committed (psycopg2 opens an implicit
	  transaction, so the drop was rolled back when the connection died);
	- the connection and cursor leaked;
	- the bare ``except:`` swallowed everything including KeyboardInterrupt.
	The best-effort semantics are preserved: a failed drop only prints a
	message, because ``mode("overwrite")`` below replaces the table anyway.

	Args:
		dataset: Spark DataFrame to persist.
		tableName: destination GreenPlum table name.
	"""
	conn = psycopg2.connect(dbname=GP_DBNAME,
	                        user=GP_PROP.get('user'),
	                        password=GP_PROP.get("password"),
	                        host=GP_HOST,
	                        port=GP_PORT)
	try:
		cursor = conn.cursor()
		try:
			cursor.execute("drop table if exists {}".format(tableName))
			conn.commit()  # without commit the DROP is rolled back
		except Exception:
			print("drop table failed")
		finally:
			cursor.close()
	finally:
		conn.close()
	dataset.write.format("jdbc") \
		.option("driver", GP_PROP.get("driver")) \
		.option("url", GP_URL).option("user", GP_PROP.get("user")) \
		.option("password", GP_PROP.get("password")) \
		.option("dbtable", tableName) \
		.mode("overwrite") \
		.save()


def saveTableForMysql(dataset, tableName):
	"""Overwrite *tableName* in MySQL with the given Spark dataset."""
	writer = dataset.write.format("jdbc")
	writer = writer.option("driver", MYSQL_PROP.get("driver"))
	writer = writer.option("url", MYSQL_URL)
	writer = writer.option("user", MYSQL_PROP.get("user"))
	writer = writer.option("password", MYSQL_PROP.get("password"))
	writer = writer.option("dbtable", tableName)
	writer.mode("overwrite").save()


def getTableMetaFromDataset(dataset):
	"""Return the dataset's column metadata as a list of (name, type) tuples."""
	meta = dataset.dtypes
	return meta


def getTableMeta(tableName):
	"""Return ``[(column_name, data_type), ...]`` for a GreenPlum table.

	Any schema prefix in *tableName* is stripped before querying
	``information_schema.columns``.  Fixes over the original: the query is
	parameterized (the table name used to be interpolated with ``format``,
	an SQL-injection vector), and the connection is closed instead of
	leaking.

	Args:
		tableName: table name, optionally schema-qualified ("schema.table").

	Returns:
		list of (column_name, data_type) tuples.
	"""
	conn = psycopg2.connect(host=GP_HOST, port=GP_PORT, dbname=GP_DBNAME, user=GP_PROP.get("user"),
	                        password=GP_PROP.get("password"))
	try:
		cursor = conn.cursor()
		bare_name = tableName.split(".")[-1]
		cursor.execute(
			"select column_name, data_type from information_schema.columns where table_name = %s",
			(bare_name,))
		return cursor.fetchall()
	finally:
		conn.close()


def saveMetaForMysql(meta, instanceId):
	"""Write the operator result meta (as JSON) into ``aiworks.task_instance.log_info``.

	Fixes over the original:
	- the SQL is parameterized; the old ``format`` interpolation broke (and
	  was injectable) whenever the JSON contained a single quote;
	- a failed ``connect`` no longer triggers ``AttributeError`` on
	  ``db.rollback()`` (``db`` was still ``None`` in that path).

	Args:
		meta: JSON-serializable result metadata (see ``getResultMeta``).
		instanceId: task instance primary key.

	Returns:
		bool: True on success, False when any step failed.
	"""
	db = None
	try:
		db = mysql.connect(host=MYSQL_HOST, user=MYSQL_USER, passwd=MYSQL_PASSWORD, db=MYSQL_DBNAME)
		cursor = db.cursor()
		sql = "update aiworks.task_instance set log_info = %s where id = %s"
		cursor.execute(sql, (json.dumps(meta, ensure_ascii=False), str(instanceId)))
		db.commit()
	except Exception as e:
		print(e)
		if db is not None:
			db.rollback()
		return False
	finally:
		if db is not None:
			db.close()
	return True


def getResultMeta(outTables=None, inputMeatas=None, status=0, msg="success"):
	"""Build the standard operator result envelope.

	The original used mutable default arguments (``[]`` / ``{}``), which are
	shared across calls in Python; they are replaced with ``None`` sentinels
	(backward compatible -- omitted arguments behave identically).

	Args:
		outTables: GreenPlum table names whose column metadata should be
			embedded (each is looked up via ``getTableMeta``).
		inputMeatas: mapping of input parameters to echo back.
		status: numeric status code (0 = success).
		msg: human-readable status message.

	Returns:
		dict with keys ``status``, ``error_msg`` and ``result`` (the latter
		holding ``input_params`` and ``output_params``).
	"""
	outTables = [] if outTables is None else outTables
	inputMeatas = {} if inputMeatas is None else inputMeatas

	ret = {
		"status": status,
		"error_msg": msg,
		"result": {}
	}
	# copy so later caller-side mutation cannot alter the envelope
	ret["result"]["input_params"] = dict(inputMeatas)

	outputParams = []
	for out in outTables:
		meta = getTableMeta(out)  # queries information_schema in GreenPlum
		outputParams.append({
			"out_table_name": out,
			"output_cols": [item[0] for item in meta],
			"col_types": [item[1] for item in meta],
		})
	ret["result"]["output_params"] = outputParams

	return ret


def load_df_pandas(tableName):
	"""Read an entire GreenPlum table into pandas, ordered by ``_record_id_``.

	Args:
		tableName: fully qualified GreenPlum table name.

	Returns:
		pandas.DataFrame with all rows of the table.
	"""
	query = "select * from {} order by _record_id_".format(tableName)
	connection = psycopg2.connect(dbname=GP_DBNAME,
	                              user=GP_USER,
	                              password=GP_PASSWORD,
	                              host=GP_HOST,
	                              port=GP_PORT)
	frame = pd.read_sql(query, connection)
	connection.close()
	return frame


def save_gp(df, table_name):
	"""Bulk-load a pandas DataFrame into GreenPlum via COPY.

	Drops and recreates *table_name* with a schema derived from the frame's
	dtypes (float -> float4, int -> integer, anything else -> varchar), then
	streams the CSV-serialized rows through ``COPY ... FROM STDIN``.

	The original best-effort semantics are preserved: any error during the
	load is printed and the function still attempts to commit.  The fix is
	that commit/close now run under ``try``/``finally`` so the cursor and
	connection can no longer leak when an unexpected error escapes.

	Args:
		df: pandas DataFrame to persist.
		table_name: destination GreenPlum table name.
	"""
	conn = psycopg2.connect(dbname=GP_DBNAME,
	                        user=GP_USER,
	                        password=GP_PASSWORD,
	                        host=GP_HOST,
	                        port=GP_PORT)
	cursor = conn.cursor()
	try:
		try:
			cursor.execute("drop table if exists {}".format(table_name))

			# map pandas dtypes onto GreenPlum column types
			pg_types = []
			for dtype in df.dtypes:
				if dtype == float:
					pg_types.append("float4")
				elif dtype == int:
					pg_types.append("integer")
				else:
					pg_types.append("varchar")
			columns = ["{} {}".format(name, pg_type) for name, pg_type in zip(df.columns, pg_types)]
			cursor.execute("create table {}({})".format(table_name, ",".join(columns)))

			data_io = io.StringIO()
			df.to_csv(data_io, sep="|", index=False)
			data_io.seek(0)
			# HEADER in the COPY command makes the server skip the csv header row
			copy_cmd = "COPY %s FROM STDIN HEADER DELIMITER '|' CSV" % table_name
			cursor.copy_expert(copy_cmd, data_io)
		except Exception as e:
			# best-effort: report and fall through to commit, as before
			print(e)
		conn.commit()
	finally:
		cursor.close()
		conn.close()

# sourceTable = 'pipeline.view_tclean_23522_1631934018268k';
# # 2.将数据读取为spark数据集
# dataframe = pidTableCheck(3109,sourceTable)
# print(dataframe)
# print(pidTableCheck("3101","pipeline.view_tclean_23522_1631934018268"))
